-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 #155645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPIRV] Add support for the SPIR-V extension SPV_KHR_bfloat16 #155645
Conversation
@llvm/pr-subscribers-llvm-globalisel @llvm/pr-subscribers-backend-spir-v Author: None (YixingZhang007) ChangesFull diff: https://github.com/llvm/llvm-project/pull/155645.diff 6 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
NoUSWrap = 1 << 20, // Instruction supports geps
// no unsigned signed wrap.
SameSign = 1 << 21, // Both operands have the same sign.
- InBounds = 1 << 22 // Pointer arithmetic remains inbounds.
+ InBounds = 1 << 22, // Pointer arithmetic remains inbounds.
// Implies NoUSWrap.
+ BFloat16 = 1 << 23 // Instruction with bf16 type
};
private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
- return false;
+ // if (containsBF16Type(U))
+ // return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (isa<FPMathOperator>(CI))
MIB->copyIRFlags(CI);
+ // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+ if (containsBF16Type(U)) {
+ MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+ }
+
for (const auto &Arg : enumerate(CI.args())) {
// If this is required to be an immediate, don't materialize it in a
// register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
+ if (I.getType()->isBFloatTy())
+ MIFlags |= MachineInstr::MIFlag::BFloat16;
+
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
- SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+ SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+ {"SPV_KHR_bfloat16",
+ SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
+ if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ }
break;
}
case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
6de7136
to
95aa9a3
Compare
The |
32498d0
to
e0b7026
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bfloat type should be encoded like OpTypeFloat 16 0
, where 0 stands for FP encoding.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing out the issue 🙂 I’ve updated the PR so that the Bfloat type is encoded using the SPIR-V instruction OpTypeFloat 16 0
, which is now distinct from the IEEE-754 float encoded as OpTypeFloat 16
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We actually can keep the enum and pass it as an argument. Reason: apart of IEEE-754 and bfloat types we will be introducing 8-bit floating point types (soon) and potentially other types later. So having the enum instead of the boolean flag would help here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for the suggestion! I’ve created an enum class FPEncoding
to store the values of the supported FP encodings and it is then passed as an argument to function getOpTypeFloat
.
https://github.com/llvm/llvm-project/blob/8e9389ddcccbbce15e484cbdf0a89f27a3c07256/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td#L2009-L2028
The function getOpTypeFloat now has two interfaces: one that takes FPEncoding
as an argument, used when the float is not an IEEE-754 type, and another without FPEncoding
, used when the float is an IEEE-754 type.
https://github.com/llvm/llvm-project/blob/8e9389ddcccbbce15e484cbdf0a89f27a3c07256/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp#L197-L198
https://github.com/llvm/llvm-project/blob/8e9389ddcccbbce15e484cbdf0a89f27a3c07256/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp#L207-L209
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a bit of TODO (unsure if it should go in this patch or in the later patches): per SPV_KHR_bfloat extension there are limited number of instructions that can use the type. For example arithmetic instructions like FAdd or FMul can't use bfloat values, hence SPIR-V backend should either emit an error or fall back to FP32 for arithmetic (probably just be calling OpFConvert to FP32 and using the result as the new value).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>; | |
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan,EnvOpenCL]>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for the suggestion! I have made the change :)
For sure! I’m okay with adding this either in this patch or in a later one. Since this PR already has quite a lot of changes (even after I move the part for supporting bfloat in the SPIR-V backend to a separate PR), I’d slightly prefer handling additional rules in a follow-up patch if there are many more to cover. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also check the bitwidth?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion! I have added the check for bitwidth here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use isBFloat16Type(MI)
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For sure, I have made the change. Thank you :)
afd09a6
to
389096a
Compare
389096a
to
8ad95c0
Compare
This PR introduces the support for the SPIR-V extension
SPV_KHR_bfloat16
. This extension extends theOpTypeFloat
instruction to enable the use of bfloat16 types with cooperative matrices and dot products.TODO:
Per the
SPV_KHR_bfloat16
extension, there are a limited number of instructions that can use the bfloat16 type. For example, arithmetic instructions likeFAdd
orFMul
can't operate onbfloat16
values. Therefore, a future patch should be added to either emit an error or fall back to FP32 for arithmetic in cases where bfloat16 must not be used.Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_bfloat16.asciidoc